CS236605: Deep Learning on Computational Accelerators

Homework Assignment 3

Faculty of Computer Science, Technion.

Submitted by:

# Name Id email
Student 1 Yun-Hsiang Chen 921180238 ctjoychen@gmail.com
Student 2 Shili Wang 921180436 sallyw0727@gmail.com

Introduction

In this assignment we'll learn to generate text with a deep multilayer RNN network based on GRU cells. Then we'll focus our attention on image generation and implement two different generative models: A variational autoencoder and a generative adversarial network.

General Guidelines

  • Please read the getting started page on the course website. It explains how to setup, run and submit the assignment.
  • This assignment requires running on GPU-enabled hardware. Please read the course servers usage guide. It explains how to use and run your code on the course servers to benefit from training with GPUs.
  • The text and code cells in these notebooks are intended to guide you through the assignment and help you verify your solutions. The notebooks do not need to be edited at all (unless you wish to play around). The only exception is to fill your name(s) in the above cell before submission. Please do not remove sections or change the order of any cells.
  • All your code (and even answers to questions) should be written in the files within the Python package corresponding to the assignment number (hw1, hw2, etc.). You can of course use any editor or IDE to work on these files.

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bb}[1]{\boldsymbol{#1}} $$

Part 1: Sequence Models

In this part we will learn about working with text sequences using recurrent neural networks. We'll go from a raw text file all the way to a fully trained GRU-RNN model and generate works of art!

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Text generation with a char-level RNN

Obtaining the corpus

Let's begin by downloading a corpus containing all the works of William Shakespeare. Since he was very prolific, this corpus is fairly large and will provide us with enough data for obtaining impressive results.

In [2]:
CORPUS_URL = 'https://github.com/cedricdeboom/character-level-rnn-datasets/raw/master/datasets/shakespeare.txt'
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')

def download_corpus(out_path=DATA_DIR, url=CORPUS_URL, force=False):
    pathlib.Path(out_path).mkdir(exist_ok=True)
    out_filename = os.path.join(out_path, os.path.basename(url))
    
    if os.path.isfile(out_filename) and not force:
        print(f'Corpus file {out_filename} exists, skipping download.')
    else:
        print(f'Downloading {url}...')
        with urllib.request.urlopen(url) as response, open(out_filename, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        print(f'Saved to {out_filename}.')
    return out_filename
    
corpus_path = download_corpus()
Corpus file /home/y.chen/.pytorch-datasets/shakespeare.txt exists, skipping download.

Load the text into memory and print a snippet:

In [3]:
with open(corpus_path, 'r') as f:
    corpus = f.read()

print(f'Corpus length: {len(corpus)} chars')
print(corpus[7:1234])
Corpus length: 6347703 chars
ALLS WELL THAT ENDS WELL

by William Shakespeare

Dramatis Personae

  KING OF FRANCE
  THE DUKE OF FLORENCE
  BERTRAM, Count of Rousillon
  LAFEU, an old lord
  PAROLLES, a follower of Bertram
  TWO FRENCH LORDS, serving with Bertram

  STEWARD, Servant to the Countess of Rousillon
  LAVACHE, a clown and Servant to the Countess of Rousillon
  A PAGE, Servant to the Countess of Rousillon

  COUNTESS OF ROUSILLON, mother to Bertram
  HELENA, a gentlewoman protected by the Countess
  A WIDOW OF FLORENCE.
  DIANA, daughter to the Widow

  VIOLENTA, neighbour and friend to the Widow
  MARIANA, neighbour and friend to the Widow

  Lords, Officers, Soldiers, etc., French and Florentine  

SCENE:
Rousillon; Paris; Florence; Marseilles

ACT I. SCENE 1.
Rousillon. The COUNT'S palace

Enter BERTRAM, the COUNTESS OF ROUSILLON, HELENA, and LAFEU, all in black

  COUNTESS. In delivering my son from me, I bury a second husband.
  BERTRAM. And I in going, madam, weep o'er my father's death anew;
    but I must attend his Majesty's command, to whom I am now in
    ward, evermore in subjection.
  LAFEU. You shall find of the King a husband, madam; you, sir, a
    father. He that so generally is at all times good must of
    

Data Preprocessing

The first thing we'll need is to map from each unique character in the corpus to an index that will represent it in our learning process.

TODO: Implement the char_maps() function in the hw3/charnn.py module.
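
For reference, here's a minimal sketch of one possible implementation (sorting the unique chars is our choice here, just to make the mapping deterministic):

def char_maps(text):
    unique_chars = sorted(set(text))  # deterministic ordering of the vocabulary
    char_to_idx = {c: i for i, c in enumerate(unique_chars)}
    idx_to_char = {i: c for i, c in enumerate(unique_chars)}
    return char_to_idx, idx_to_char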

In [4]:
import hw3.charnn as charnn

char_to_idx, idx_to_char = charnn.char_maps(corpus)
print(char_to_idx)

test.assertEqual(len(char_to_idx), len(idx_to_char))
test.assertSequenceEqual(list(char_to_idx.keys()), list(idx_to_char.values()))
test.assertSequenceEqual(list(char_to_idx.values()), list(idx_to_char.keys()))
{'\n': 0, ' ': 1, '!': 2, '"': 3, '$': 4, '&': 5, "'": 6, '(': 7, ')': 8, ',': 9, '-': 10, '.': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21, ':': 22, ';': 23, '<': 24, '?': 25, 'A': 26, 'B': 27, 'C': 28, 'D': 29, 'E': 30, 'F': 31, 'G': 32, 'H': 33, 'I': 34, 'J': 35, 'K': 36, 'L': 37, 'M': 38, 'N': 39, 'O': 40, 'P': 41, 'Q': 42, 'R': 43, 'S': 44, 'T': 45, 'U': 46, 'V': 47, 'W': 48, 'X': 49, 'Y': 50, 'Z': 51, '[': 52, ']': 53, '_': 54, 'a': 55, 'b': 56, 'c': 57, 'd': 58, 'e': 59, 'f': 60, 'g': 61, 'h': 62, 'i': 63, 'j': 64, 'k': 65, 'l': 66, 'm': 67, 'n': 68, 'o': 69, 'p': 70, 'q': 71, 'r': 72, 's': 73, 't': 74, 'u': 75, 'v': 76, 'w': 77, 'x': 78, 'y': 79, 'z': 80, '}': 81, '\ufeff': 82}

Seems we have some strange characters in the corpus that are very rare and are probably due to mistakes. Since removing them reduces the vocabulary size, and thus the length of the one-hot tensors we'll later use to represent our chars, it's best to remove them.

TODO: Implement the remove_chars() function in the hw3/charnn.py module.

In [5]:
corpus, n_removed = charnn.remove_chars(corpus, ['}','$','_','<','\ufeff'])
print(f'Removed {n_removed} chars')

# After removing the chars, re-create the mappings
char_to_idx, idx_to_char = charnn.char_maps(corpus)
Removed 6347669 chars

The next thing we need is an embedding of the characters. An embedding is a representation of each token from the sequence as a tensor. For a char-level RNN our tokens are chars, so we can use the simplest possible embedding: encode each char as a one-hot tensor. In other words, each char will be represented as a tensor whose length is the total number of unique chars (V), containing all zeros except at the index corresponding to that specific char.

TODO: Implement the functions chars_to_onehot() and onehot_to_chars() in the hw3/charnn.py module.
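
A minimal sketch of these two functions (the int8 dtype matches the test below):

def chars_to_onehot(text, char_to_idx):
    # Build an (S, V) tensor with a single 1 per row
    idx = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
    onehot = torch.zeros(len(text), len(char_to_idx), dtype=torch.int8)
    onehot[torch.arange(len(text)), idx] = 1
    return onehot

def onehot_to_chars(embedding, idx_to_char):
    idx = torch.argmax(embedding, dim=1)  # recover the index of the 1 in each row
    return ''.join(idx_to_char[i.item()] for i in idx)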

In [6]:
# Wrap the actual embedding functions for calling convenience
def embed(text):
    return charnn.chars_to_onehot(text, char_to_idx)

def unembed(embedding):
    return charnn.onehot_to_chars(embedding, idx_to_char)

text_snippet = corpus[3104:3148]
print(text_snippet)
print(embed(text_snippet[0:3]))

test.assertEqual(text_snippet, unembed(embed(text_snippet)))
test.assertEqual(embed(text_snippet).dtype, torch.int8)
brine a maiden can season her praise in.
   
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]], dtype=torch.int8)

Dataset Creation

We wish to train our model to generate text by constantly predicting what the next char should be based on the past. To that end we'll need to train our recurrent network in a way similar to a classification task. At each timestep, we input a char and set the expected output (label) to be the next char in the original sequence.

We will split our corpus into shorter sequences of length S chars (try to think why; see the question below). Each sample we provide our model with will therefore be a tensor of shape (S,V), where V is the embedding dimension. Our model will operate sequentially on each char in the sequence. For each sample, we'll also need a label. This is simply another sequence, shifted by one char, so that the label of each char is the next char in the corpus.

TODO: Implement the chars_to_labelled_samples() function in the hw3/charnn.py module.
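
A minimal sketch, reusing chars_to_onehot() from above (trailing chars that don't fill a whole sequence are dropped):

def chars_to_labelled_samples(corpus, char_to_idx, seq_len, device):
    num_samples = (len(corpus) - 1) // seq_len
    onehot = chars_to_onehot(corpus, char_to_idx).to(device)
    # Samples: the first num_samples*seq_len chars, folded into (N, S, V)
    samples = onehot[:num_samples * seq_len].reshape(num_samples, seq_len, -1)
    # Labels: the same chars shifted by one, kept as indices rather than one-hot
    idx = onehot.argmax(dim=1)
    labels = idx[1:num_samples * seq_len + 1].reshape(num_samples, seq_len)
    return samples, labels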

In [7]:
# Create dataset of sequences
seq_len = 64
vocab_len = len(char_to_idx)

# Create labelled samples
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
print(f'samples shape: {samples.shape}')
print(f'labels shape: {labels.shape}')

# Test shapes
num_samples = (len(corpus) - 1) // seq_len
test.assertEqual(samples.shape, (num_samples, seq_len, vocab_len))
test.assertEqual(labels.shape, (num_samples, seq_len))

# Test content
for _ in range(1000):
    # random sample
    i = np.random.randint(num_samples, size=(1,))[0]
    # Compare to corpus
    test.assertEqual(unembed(samples[i]), corpus[i*seq_len:(i+1)*seq_len], msg=f"content mismatch in sample {i}")
    # Compare to labels
    sample_text = unembed(samples[i])
    label_text = str.join('', [idx_to_char[j.item()] for j in labels[i]])
    test.assertEqual(sample_text[1:], label_text[0:-1], msg=f"label mismatch in sample {i}")
    
print(f'sample 100 as text:\n{unembed(samples[100])}')
samples shape: torch.Size([99182, 64, 78])
labels shape: torch.Size([99182, 64])
sample 100 as text:
nity, though valiant in the
    defence, yet is weak. Unfold to 

As usual, instead of feeding one sample at a time into our model's forward pass, we'll work with batches of samples. This means that at every timestep, our model will operate on a batch of chars that come from different sequences. Effectively, this allows us to parallelize training by doing matrix-matrix multiplications instead of matrix-vector multiplications during the forward pass.

Let's use the standard PyTorch Dataset/DataLoader combo. Luckily for the dataset we can use a built-in class, TensorDataset to return tuples of (sample, label) from the samples and labels tensors we created above.

In [8]:
import torch.utils.data

# Create DataLoader returning batches of samples.
batch_size = 32

ds_corpus = torch.utils.data.TensorDataset(samples, labels)
dl_corpus = torch.utils.data.DataLoader(ds_corpus, batch_size=batch_size, shuffle=False)

Let's see what that gives us:

In [9]:
print(f'num batches: {len(dl_corpus)}')

x0, y0 = next(iter(dl_corpus))
print(f'shape of a batch sample: {x0.shape}')
print(f'shape of a batch label: {y0.shape}')
num batches: 3100
shape of a batch sample: torch.Size([32, 64, 78])
shape of a batch label: torch.Size([32, 64])

Model Implementation

Finally, our data set is ready so we can focus on our model.

What we'll implement here is a multilayer gated recurrent unit (GRU) model, with dropout. This model is a type of RNN which performs similarly to the well-known LSTM model, but it's somewhat easier to train because it has fewer parameters. We'll modify the regular GRU slightly by applying dropout to the hidden states passed between layers of the model.

The model accepts an input $\mat{X}\in\set{R}^{S\times V}$ containing a sequence of embedded chars. It returns an output $\mat{Y}\in\set{R}^{S\times V}$ of predictions for the next char and the final hidden state $\mat{H}\in\set{R}^{L\times H}$. Here $S$ is the sequence length, $V$ is the vocabulary size (number of unique chars), $L$ is the number of layers in the model and $H$ is the hidden dimension.

Mathematically, the model's forward function at layer $k\in[1,L]$ and timestep $t\in[1,S]$ can be described as

$$ \begin{align} \vec{z_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xz}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hz}}}^{[k]} + \vec{b}_{\mathrm{z}}^{[k]}\right) \\ \vec{r_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xr}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hr}}}^{[k]} + \vec{b}_{\mathrm{r}}^{[k]}\right) \\ \vec{g_t}^{[k]} &= \tanh\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xg}}}^{[k]} + (\vec{r_t}^{[k]}\odot\vec{h}_{t-1}^{[k]}) {\mattr{W}_{\mathrm{hg}}}^{[k]} + \vec{b}_{\mathrm{g}}^{[k]}\right) \\ \vec{h_t}^{[k]} &= \vec{z}^{[k]}_t \odot \vec{h}^{[k]}_{t-1} + \left(1-\vec{z}^{[k]}_t\right)\odot \vec{g_t}^{[k]} \end{align} $$

The input to each layer is, $$ \mat{X}^{[k]} = \begin{bmatrix} {\vec{x}_1}^{[k]} \\ \vdots \\ {\vec{x}_S}^{[k]} \end{bmatrix} = \begin{cases} \mat{X} & \mathrm{if} ~k = 1~ \\ \mathrm{dropout}_p \left( \begin{bmatrix} {\vec{h}_1}^{[k-1]} \\ \vdots \\ {\vec{h}_S}^{[k-1]} \end{bmatrix} \right) & \mathrm{if} ~1 < k \leq L+1~ \end{cases}. $$

The output of the entire model is then, $$ \mat{Y} = \mat{X}^{[L+1]} {\mattr{W}_{\mathrm{hy}}} + \mat{B}_{\mathrm{y}} $$

and the final hidden state is $$ \mat{H} = \begin{bmatrix} {\vec{h}_S}^{[1]} \\ \vdots \\ {\vec{h}_S}^{[L]} \end{bmatrix}. $$

Notes:

  • $t\in[1,S]$ is the timestep, i.e. the current position within the sequence of each sample.
  • $\vec{x}_t^{[k]}$ is the input of layer $k$ at timestep $t$.
  • The outputs of the last layer, $\vec{y}_t^{[L]}$, are the predicted next characters for every input char. These are similar to class scores in classification tasks.
  • The hidden states at the last timestep, $\vec{h}_S^{[k]}$, are the final hidden state returned from the model.
  • $\sigma(\cdot)$ is the sigmoid function, i.e. $\sigma(\vec{z}) = 1/(1+e^{-\vec{z}})$ which returns values in $(0,1)$.
  • $\tanh(\cdot)$ is the hyperbolic tangent, i.e. $\tanh(\vec{z}) = (e^{2\vec{z}}-1)/(e^{2\vec{z}}+1)$ which returns values in $(-1,1)$.
  • $\vec{h_t}^{[k]}$ is the hidden state of layer $k$ at time $t$. This can be thought of as the memory of that layer.
  • $\vec{g_t}^{[k]}$ is the candidate hidden state for time $t$, i.e. the "next" state relative to $\vec{h}_{t-1}^{[k]}$.
  • $\vec{z_t}^{[k]}$ is known as the update gate. It combines the previous state with the input to determine how much the current state will be combined with the new candidate state. For example, if $\vec{z_t}^{[k]}=\vec{1}$ then the current input has no effect on the output.
  • $\vec{r_t}^{[k]}$ is known as the reset gate. It combines the previous state with the input to determine how much of the previous state will affect the current state candidate. For example if $\vec{r_t}^{[k]}=\vec{0}$ the previous state has no effect on the current candidate state.
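
To make the equations concrete, here is a minimal sketch of a single batched GRU timestep at one layer (the w_* arguments mirror the Linear modules of the model printed further below, with each layer's bias folded into the h-side module):

def gru_cell(x_t, h_prev, w_xz, w_hz, w_xr, w_hr, w_xg, w_hg):
    # x_t: (B, in_dim), h_prev: (B, h_dim)
    z = torch.sigmoid(w_xz(x_t) + w_hz(h_prev))   # update gate
    r = torch.sigmoid(w_xr(x_t) + w_hr(h_prev))   # reset gate
    g = torch.tanh(w_xg(x_t) + w_hg(r * h_prev))  # candidate state
    return z * h_prev + (1 - z) * g               # new hidden state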

Here's a graphical representation of the GRU's forward pass at each timestep. The $\vec{\tilde{h}}$ in the image is our $\vec{g}$ (candidate next state).

You can see how the reset and update gates allow the model to completely ignore its previous state, completely ignore its input, or use any mixture of the two (since the gates are continuous, taking values in $(0,1)$).

Here's a graphical representation of the entire model. You can ignore the $c_t^{[k]}$ (cell state) variables (which are relevant for LSTM models). Our model has only the hidden state, $h_t^{[k]}$. Also notice that we added dropout between layers (the up arrows).

The purple tensors are inputs (a sequence and initial hidden state per layer), and the green tensors are outputs (another sequence and final hidden state per layer). Each blue block implements the above forward equations. Blocks that are on the same vertical level are at the same layer, and therefore share parameters.

TODO: Implement the MultilayerGRU class in the hw3/charnn.py module.

Notes:

  • You'll need to handle input batches now. The math is identical to the above, but all the tensors will have an extra batch dimension as their first dimension.
  • Use the diagram above to help guide your implementation. It will help you visualize what shapes to return where, etc.
In [10]:
in_dim = vocab_len
h_dim = 256
n_layers = 2
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers)
model = model.to(device)
print(model)

# Test forward pass
y, h = model(x0.to(dtype=torch.float))
print(f'y.shape={y.shape}')
print(f'h.shape={h.shape}')

test.assertEqual(y.shape, (batch_size, seq_len, vocab_len))
test.assertEqual(h.shape, (batch_size, n_layers, h_dim))
test.assertEqual(len(list(model.parameters())), 9 * n_layers + 2) 
MultilayerGRU(
  (w_xz_0): Linear(in_features=78, out_features=256, bias=False)
  (w_xr_0): Linear(in_features=78, out_features=256, bias=False)
  (w_xg_0): Linear(in_features=78, out_features=256, bias=False)
  (w_hz_0): Linear(in_features=256, out_features=256, bias=True)
  (w_hr_0): Linear(in_features=256, out_features=256, bias=True)
  (w_hg_0): Linear(in_features=256, out_features=256, bias=True)
  (dropout_0): Dropout(p=0)
  (w_xz_1): Linear(in_features=256, out_features=256, bias=False)
  (w_xr_1): Linear(in_features=256, out_features=256, bias=False)
  (w_xg_1): Linear(in_features=256, out_features=256, bias=False)
  (w_hz_1): Linear(in_features=256, out_features=256, bias=True)
  (w_hr_1): Linear(in_features=256, out_features=256, bias=True)
  (w_hg_1): Linear(in_features=256, out_features=256, bias=True)
  (dropout_1): Dropout(p=0)
  (w_y): Linear(in_features=256, out_features=78, bias=True)
)
y.shape=torch.Size([32, 64, 78])
h.shape=torch.Size([32, 2, 256])

Generating text by sampling

Now that we have a model, we can implement text generation based on it. The idea is simple: at each timestep our model receives one char $x_t$ from the input sequence and outputs scores $y_t$ for what the next char should be. We'll convert these scores into a probability distribution over the possible chars. In other words, for each input char $x_t$ we create a probability distribution for the next char, conditioned on the current one and the state of the model (representing all previous inputs): $$p(x_{t+1}|x_t; \vec{h}_t).$$

Once we have such a distribution, we'll sample a char from it. This will be the first char of our generated sequence. Now we can feed this new char into the model, create another distribution, sample the next char and so on. Note that it's crucial to propagate the hidden state when sampling.

The important point, however, is how to create the distribution from the scores. One way, as we saw in previous ML tasks, is to use the softmax function. However, a drawback of softmax is that it generates very diffuse (more uniform) distributions when the score values are similar. When sampling, we would prefer to control the distribution and make it less uniform, to increase the chance of sampling the char(s) with the highest scores compared to the others.

To control the variance of the distribution, a common trick is to add a hyperparameter $T$, known as the temperature, to the softmax function. The class scores are simply divided by $T$ before the softmax is applied: $$ \mathrm{softmax}_T(\vec{y}) = \frac{e^{\vec{y}/T}}{\sum_k e^{y_k/T}} $$

A low $T$ will result in less uniform distributions and vice-versa.

TODO: Implement the hot_softmax() function in the hw3/charnn.py module.
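
A minimal sketch (the default dim is an assumption):

def hot_softmax(y, dim=0, temperature=1.0):
    # Dividing the scores by T < 1 sharpens the distribution; T > 1 flattens it
    return torch.softmax(y / temperature, dim=dim)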

In [11]:
scores = y[0,0,:].detach()
_, ax = plt.subplots(figsize=(15,5))

for t in reversed([0.3, 0.5, 1.0, 100]):
    ax.plot(charnn.hot_softmax(scores, temperature=t).cpu().numpy(), label=f'T={t}')
ax.set_xlabel('$x_{t+1}$')
ax.set_ylabel('$p(x_{t+1}|x_t)$')
ax.legend()

uniform_proba = 1/len(char_to_idx)
uniform_diff = torch.abs(charnn.hot_softmax(scores, temperature=100) - uniform_proba)
test.assertTrue(torch.all(uniform_diff < 1e-4))

TODO: Implement the generate_from_model() function in the hw3/charnn.py module.
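
A minimal sketch of the sampling loop described above (it assumes the model's forward accepts an optional initial hidden state, defaulting to None/zeros):

def generate_from_model(model, start_sequence, n_chars, char_maps, T):
    char_to_idx, idx_to_char = char_maps
    device = next(model.parameters()).device
    out_text = start_sequence
    with torch.no_grad():
        x = chars_to_onehot(start_sequence, char_to_idx)
        x = x.to(device=device, dtype=torch.float).unsqueeze(0)
        h = None
        while len(out_text) < n_chars:
            y, h = model(x, h)  # crucial: propagate the hidden state
            # Distribution for the next char, from the last timestep's scores
            p = hot_softmax(y[0, -1, :], temperature=T)
            next_char = idx_to_char[torch.multinomial(p, num_samples=1).item()]
            out_text += next_char
            x = chars_to_onehot(next_char, char_to_idx)
            x = x.to(device=device, dtype=torch.float).unsqueeze(0)
    return out_text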

In [12]:
for _ in range(3):
    text = charnn.generate_from_model(model, "foobar", 50, (char_to_idx, idx_to_char), T=0.5)
    print(text)
    test.assertEqual(len(text), 50)
foobar)R?"qOdjik-L"17c9nf"-pMnbyt&a2epe(f:U]Sj,rd0
foobar;RG;gp"
)T[7Yo[GRy88Gdd;Oy8:KYDd4Fj'WmtjGMng
foobarHL2Yl-ujwnmWfE25dAH3OXtW6:3QJp[eL-.HXqLDa8zC

Training

To train such a model, we'll calculate the loss at each timestep by comparing the predicted char to the actual char from our label. We can use cross-entropy, since per char this is similar to a classification problem. We'll then sum the losses over the sequence and back-propagate the gradients through time. Notice that the back-propagation algorithm will "visit" each layer's parameter tensors multiple times, so gradients will accumulate in the parameters of the blocks. Luckily, autograd handles this part for us.
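
In code, this per-timestep loss can be computed in one shot by flattening the batch and sequence dimensions (a sketch under the shape assumptions above; note that nn.CrossEntropyLoss averages rather than sums by default, which only rescales the gradients):

# y: (B, S, V) model outputs; y_true: (B, S) next-char indices
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(y.reshape(-1, y.shape[-1]), y_true.reshape(-1))
loss.backward()  # back-propagation through time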

As usual, the first step of training will be to try and overfit a large model (many parameters) to a tiny dataset. Again, this is to ensure the model and training code are implemented correctly, i.e. that the model can learn.

For a generative model such as this, overfitting is slightly trickier than for classification. What we'll aim to do is get our model to memorize a specific sequence of chars, so that when given the first char in the sequence it will immediately spit out the rest of the sequence verbatim.

Let's create a tiny dataset to memorize.

In [13]:
# Pick a tiny subset of the dataset
subset_start, subset_end = 1001, 1005
ds_corpus_ss = torch.utils.data.Subset(ds_corpus, range(subset_start, subset_end))
dl_corpus_ss = torch.utils.data.DataLoader(ds_corpus_ss, batch_size=1, shuffle=False)

# Convert subset to text
subset_text = ''
for i in range(subset_end - subset_start):
    subset_text += unembed(ds_corpus_ss[i][0])
print(f'Text to "memorize":\n\n{subset_text}')
Text to "memorize":

TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

Now let's implement the first part of our training code.

TODO: Implement the train_epoch() and train_batch() methods of the RNNTrainer class in the hw3/training.py module. Note: Think about how to correctly handle the hidden state of the model between batches and epochs (for this specific task, i.e. text generation).
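
One common way to handle the hint about the hidden state (a sketch; self.h is an assumed trainer attribute, reset to None at the start of each epoch since the corpus restarts):

# Inside RNNTrainer.train_batch (sketch):
if self.h is not None:
    self.h = self.h.detach()  # keep the state across batches, but truncate BPTT here
y_pred, self.h = self.model(x.to(dtype=torch.float), self.h)

Since the DataLoader was created with shuffle=False, the batches follow the corpus order, which is what makes carrying the (detached) hidden state between batches meaningful.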

In [14]:
import torch.nn as nn
import torch.optim as optim
from hw3.training import RNNTrainer

torch.manual_seed(42)

lr = 0.01
num_epochs = 500

in_dim = vocab_len
h_dim = 128
n_layers = 2
loss_fn = nn.CrossEntropyLoss()
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

for epoch in range(num_epochs):
    epoch_result = trainer.train_epoch(dl_corpus_ss, verbose=False)
    
    # Every X epochs, we'll generate a sequence starting from the first char in the first sequence
    # to visualize how/if/what the model is learning.
    if epoch == 0 or (epoch+1) % 25 == 0:
        avg_loss = np.mean(epoch_result.losses)
        accuracy = np.mean(epoch_result.accuracy)
        print(f'\nEpoch #{epoch+1}: Avg. loss = {avg_loss:.3f}, Accuracy = {accuracy:.2f}%')
        
        generated_sequence = charnn.generate_from_model(model, subset_text[0],
                                                        seq_len*(subset_end-subset_start),
                                                        (char_to_idx,idx_to_char), T=0.1)
        # Stop if we've successfully memorized the small dataset.
        print(generated_sequence)
        if generated_sequence == subset_text:
            break

# Test successful overfitting
test.assertGreater(epoch_result.accuracy, 99)
test.assertEqual(generated_sequence, subset_text)
Epoch #1: Avg. loss = 3.956, Accuracy = 18.36%
Tnn                                                                                                                                                                                                                                                             

Epoch #25: Avg. loss = 0.214, Accuracy = 97.66%
TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    Faith, yes:
    Strangers and foes sousder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HELERAM. I pray you, stay not, but in haste to horse

Epoch #50: Avg. loss = 0.009, Accuracy = 100.00%
TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

OK, so training works - we can memorize a short sequence. Next on the agenda is to split our full dataset into training and test sets of batched sequences.

In [15]:
# Full dataset definition
vocab_len = len(char_to_idx)
seq_len = 64
batch_size = 256
train_test_ratio = 0.9
num_samples = (len(corpus) - 1) // seq_len
num_train = int(train_test_ratio * num_samples)

samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)

ds_train = torch.utils.data.TensorDataset(samples[:num_train], labels[:num_train])
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=False, drop_last=True)

ds_test = torch.utils.data.TensorDataset(samples[num_train:], labels[num_train:])
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False, drop_last=True)

print(f'Train: {len(dl_train):3d} batches, {len(dl_train)*batch_size*seq_len:7d} chars')
print(f'Test:  {len(dl_test):3d} batches, {len(dl_test)*batch_size*seq_len:7d} chars')
Train: 348 batches, 5701632 chars
Test:   38 batches,  622592 chars

We'll now train a much larger model on our large dataset. You'll need a GPU for this part.

The code blocks below will train the model and save checkpoints containing the training state and the best model parameters to a file. This allows you to stop training and resume it later from where you left off.

Note that you can use the main.py script provided within the assignment folder to run this notebook from the command line as if it were a python script by using the run-nb subcommand. This allows you to train your model using this notebook without starting jupyter. You can combine this with srun or sbatch to run the notebook with a GPU on the course servers.

In [16]:
# Full training definition
lr = 0.001
num_epochs = 50

in_dim = out_dim = vocab_len
hidden_dim = 512
n_layers = 3
dropout = 0.5
checkpoint_file = 'checkpoints/rnn'
max_batches = 300
early_stopping = 5

model = charnn.MultilayerGRU(in_dim, hidden_dim, out_dim, n_layers, dropout)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

TODO:

  • Implement the fit() method of the Trainer class. You can reuse the implementation from HW2, but make sure to implement early stopping and checkpoints.
  • Implement the test_epoch() and test_batch() methods of the RNNTrainer class in the hw3/training.py module.
  • Run the following block to train.
In [17]:
from cs236605.plot import plot_fit

def post_epoch_fn(epoch, test_res, train_res, verbose):
    # Update learning rate
    scheduler.step(test_res.accuracy)
    # Sample from model to show progress
    if verbose:
        start_seq = "ACT I."
        generated_sequence = charnn.generate_from_model(
            model, start_seq, 100, (char_to_idx,idx_to_char), T=0.5
        )
        print(generated_sequence)

# Train, unless final checkpoint is found
checkpoint_file_final = f'{checkpoint_file}_final.pt'
if os.path.isfile(checkpoint_file_final):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    saved_state = torch.load(checkpoint_file_final, map_location=device)
    model.load_state_dict(saved_state['model_state'])
else:
    try:
        # Print pre-training sampling
        print(charnn.generate_from_model(model, "ACT I.", 100, (char_to_idx,idx_to_char), T=0.5))

        fit_res = trainer.fit(dl_train, dl_test, num_epochs, max_batches=max_batches,
                              post_epoch_fn=post_epoch_fn, early_stopping=early_stopping,
                              checkpoints=checkpoint_file, print_every=1)
        
        fig, axes = plot_fit(fit_res)
    except KeyboardInterrupt as e:
        print('\n *** Training interrupted by user')
ACT I."VU0P,6rT.5:k w
5qPM(v'yC6!Ml!OEt4ERxDX rC4,bVZJops]
ram8"8fBmkZKOIv
bna5'ZHVdZ.LzvUeEzO98GMq0
*** Loading checkpoint file checkpoints/rnn.pt
--- EPOCH 1/50 ---
train_batch (Avg. Loss 1.271, Accuracy 61.7): 100%|██████████| 300/300 [01:08<00:00,  4.38it/s]
test_batch (Avg. Loss 1.806, Accuracy 50.7): 100%|██████████| 38/38 [00:03<00:00, 11.71it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 1
ACT I. SCENE I.
The palace

Enter CHARINIUS, and PATROLLUS

  CLOWN. I have a thousand sparing-day, 
--- EPOCH 2/50 ---
train_batch (Avg. Loss 1.254, Accuracy 62.1): 100%|██████████| 300/300 [01:08<00:00,  4.39it/s]
test_batch (Avg. Loss 1.805, Accuracy 50.8): 100%|██████████| 38/38 [00:03<00:00, 11.68it/s]
ACT I. Sir, I could be shortly see your sight;
    And so much as I do not do my sword and a deed of
--- EPOCH 3/50 ---
train_batch (Avg. Loss 1.238, Accuracy 62.5): 100%|██████████| 300/300 [01:08<00:00,  4.40it/s]
test_batch (Avg. Loss 1.799, Accuracy 50.9): 100%|██████████| 38/38 [00:03<00:00, 11.74it/s]
ACT I.
A state of the sons of the world, I would make thee a fair beard,
    And make the shame of s
--- EPOCH 4/50 ---
train_batch (Avg. Loss 1.222, Accuracy 62.9): 100%|██████████| 300/300 [01:08<00:00,  4.39it/s]
test_batch (Avg. Loss 1.792, Accuracy 51.0): 100%|██████████| 38/38 [00:03<00:00, 11.73it/s]
ACT I. SCENE I.
Venice.

                                                                           
--- EPOCH 5/50 ---
train_batch (Avg. Loss 1.209, Accuracy 63.3): 100%|██████████| 300/300 [01:05<00:00,  4.65it/s]
test_batch (Avg. Loss 1.792, Accuracy 51.3): 100%|██████████| 38/38 [00:03<00:00, 11.88it/s]
ACT I. SCENE 1.
Troy. I say, if I do not speak with you.
  MALVOLIO. 'Tis a man to himself and all t
--- EPOCH 6/50 ---
train_batch (Avg. Loss 1.196, Accuracy 63.6): 100%|██████████| 300/300 [01:04<00:00,  4.72it/s]
test_batch (Avg. Loss 1.795, Accuracy 51.4): 100%|██████████| 38/38 [00:03<00:00, 12.47it/s]
ACT I. SCENE I.
The tent of all the walls
    Shall be so fair from the noise of my love.
          
--- EPOCH 7/50 ---
train_batch (Avg. Loss 1.186, Accuracy 63.9): 100%|██████████| 300/300 [01:03<00:00,  4.69it/s]
test_batch (Avg. Loss 1.801, Accuracy 51.5): 100%|██████████| 38/38 [00:03<00:00, 12.55it/s]

Generating a work of art

Armed with our fully trained model, let's generate the next Hamlet! You should experiment with modifying the sampling temperature and see what happens.

TODO: Specify the generation parameters in the part1_generation_params() function within the hw3/answers.py module.

In [18]:
import hw3.answers

start_seq, temperature = hw3.answers.part1_generation_params()

generated_sequence = charnn.generate_from_model(
    model, start_seq, 10000, (char_to_idx,idx_to_char), T=temperature
)

print(generated_sequence)
HAMLET. What shall I do and see the streets of the world to say that I am a strait is that the word of many man in my life.
    I will not be a strain of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man 
of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of 
many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many man of many m

Questions

TODO: Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [19]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

Why do we split the corpus into sequences instead of training on the whole text?

In [20]:
display_answer(hw3.answers.part1_q1)

Your answer:

"We will split our corpus into shorter sequences of length S chars. Each sample we provide our model with will therefore be a tensor of shape (S,V) where V is the embedding dimension. Our model will operate sequentially on each char in the sequence." Because if training on the whole dataset, the dataset is just 1 sample which doesn't contain enough information for the network to train.

Question 2

How is it possible that the generated text clearly shows memory longer than the sequence length?

In [21]:
display_answer(hw3.answers.part1_q2)

Your answer:

Memory longer than the sequence length comes from the hidden state, not from the input sequence. Past information is stored in the hidden state, which is propagated forward from step to step (and between batches), so information can persist and affect characters generated much later.

Question 3

Why are we not shuffling the order of batches when training?

In [22]:
display_answer(hw3.answers.part1_q3)

Your answer:

The reason for not shuffling is the time dependence of the original text: earlier text affects later text, and consecutive batches continue the same sequences, so the hidden state carried between batches stays meaningful. If we shuffled the batches while training, this temporal dependence would be lost.

Question 4

  1. Why do we lower the temperature for sampling (compared to the default of $1.0$ when training)?
  2. What happens when the temperature is very high and why?
  3. What happens when the temperature is very low and why?
In [23]:
display_answer(hw3.answers.part1_q4)

Your answer:

1. Temperature sampling works by increasing the probability of the most likely chars before sampling. For T=1 the scaling is the identity. Lowering the temperature sharpens the distribution, so the chars with the highest scores are much more likely to be sampled and the generated text is less random.

2. For high temperatures (T→∞), all chars get nearly the same probability. Sampling approaches uniform, so the generated text becomes more diverse but also more random and error-prone.

3. For a low temperature (T→0+), the probability of the char with the highest score tends to 1, so sampling becomes nearly deterministic. The generated text is more conservative and tends to get stuck repeating the most likely patterns, as the repetitive output above demonstrates.

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 2: Variational Autoencoder

In this part we will learn to generate new data using a special type of autoencoder model which allows us to sample from its latent space. We'll implement and train a VAE and use it to generate new images.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Obtaining the dataset

Let's begin by downloading a dataset of images that we want to learn to generate. We'll use the Labeled Faces in the Wild (LFW) dataset, which contains many labeled faces of famous individuals.

We're going to train our generative model to generate a specific face, not just any face. Since the person with the most images in this dataset is former president George W. Bush, we'll set out to train a Bush Generator :)

However, if you feel adventurous and/or prefer to generate something else, feel free to edit the PART2_CUSTOM_DATA_URL variable in hw3/answers.py.

In [2]:
import cs236605.plot as plot
import cs236605.download
from hw3.answers import PART2_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236605.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/y.chen/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/y.chen/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/y.chen/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extracted images:

In [3]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [4]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(10,5), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [5]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

The Variational Autoencoder

An autoencoder is a model which learns a representation of data in an unsupervised fashion (i.e. without any labels). Recall its general form from the lecture:

An autoencoder maps an instance $\bb{x}$ to a latent-space representation $\bb{z}$. It has an encoder part, $\Phi_{\bb{\alpha}}(\bb{x})$ (a neural net with parameters $\bb{\alpha}$) and a decoder part, $\Psi_{\bb{\beta}}(\bb{z})$ (a neural net with parameters $\bb{\beta}$).

While autoencoders can learn useful representations, generally it's hard to use them as generative models because there's no distribution we can sample from in the latent space. In other words, we have no way to choose a point $\bb{z}$ in the latent space such that $\Psi(\bb{z})$ will end up on the data manifold in the instance space.

The variational autoencoder (VAE), first proposed by Kingma and Welling, addresses this issue by taking a probabilistic perspective. Briefly, a VAE model can be described as follows.

We define, in Bayesian terminology,

  • The prior distribution $p(\bb{Z})$ on points in the latent space.
  • The likelihood distribution of a sample $\bb{X}$ given a latent-space representation: $p(\bb{X}|\bb{Z})$.
  • The posterior distribution of points in the latent space given a specific instance: $p(\bb{Z}|\bb{X})$.
  • The evidence distribution $p(\bb{X})$ which is the distribution of the instance space due to the generative process.

To create our variational decoder we'll further specify:

  • A parametric likelihood distribution, $p _{\bb{\beta}}(\bb{X} | \bb{z}) = \mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$. The interpretation is that given a latent $\bb{z}$, we map it to a point normally distributed around the point calculated by our decoder neural network. Note that here $\sigma^2$ is a hyperparameter while $\vec{\beta}$ represents the network parameters.
  • A fixed latent-space prior distribution of $p(\bb{Z}) = \mathcal{N}(\bb{0},\bb{I})$.

This setting allows us to generate a new instance $\bb{x}$ by sampling $\bb{z}$ from the multivariate normal distribution, obtaining the instance-space mean $\Psi _{\bb{\beta}}(\bb{z})$ using our decoder network, and then sampling $\bb{x}$ from $\mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$.
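
Once trained, generating an image only requires the decoder. A minimal sketch (decoder stands for the trained $\Psi_{\bb{\beta}}$ network and sigma2 for the likelihood-variance hyperparameter):

z = torch.randn(1, z_dim, device=device)                  # z ~ p(Z) = N(0, I)
x_mean = decoder(z)                                       # Psi_beta(z)
x = x_mean + (sigma2 ** 0.5) * torch.randn_like(x_mean)   # x ~ N(Psi_beta(z), sigma2*I)

In practice one often just displays the decoder mean x_mean rather than a noisy sample.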

Our variational encoder will approximate the posterior with a parametric distribution $q _{\bb{\alpha}}(\bb{Z} | \bb{x}) \sim \mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$. The interpretation is that our encoder neural network, $\Phi_{\vec{\alpha}}(\bb{x})$, calculates the mean and variance of the posterior distribution, and samples $\bb{z}$ based on them. An important nuance here is that our network can't contain any stochastic elements that depend on the model parameters, otherwise we won't be able to back-propagate to those parameters. So sampling $\bb{z}$ directly from $\mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$ is not an option. The solution is to use what's known as the reparametrization trick: sample from an isotropic Gaussian, i.e. $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ (which doesn't depend on trainable parameters), and calculate the latent representation as $\bb{z} = \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{u}\odot\bb{\sigma}_{\bb{\alpha}}(\bb{x})$.
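
In code, the reparametrization trick is essentially a one-liner (a sketch; mu and sigma are the outputs of the encoder network):

u = torch.randn_like(mu)   # u ~ N(0, I): randomness independent of the parameters
z = mu + u * sigma         # differentiable w.r.t. mu and sigma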

To train a VAE model, we would like to maximize the evidence, $p(\bb{X})$, because $ p(\bb{X}) = \int p(\bb{X}|{\bb{z}})p(\bb{z})d\bb{z} $, thus maximizing the likelihood of generated instances over the entire latent space.

The VAE loss can therefore be stated as minimizing $\mathcal{L} = -\mathbb{E}_{\bb{x}} \log p(\bb{X})$. As we saw in the lecture, this expectation is intractable, but we can obtain a lower-bound for $p(\bb{X})$ (the evidence lower bound, "ELBO"):

$$ \log p(\bb{X}) \ge \mathbb{E}_{\bb{z} \sim q_{\bb{\alpha}}}\left[ \log p_{\bb{\beta}}(\bb{X} | \bb{z}) \right] - \mathcal{D}_{\mathrm{KL}}\left(q_{\bb{\alpha}}(\bb{Z} | \bb{X})\,\left\|\, p(\bb{Z})\right.\right) $$

where $ \mathcal{D}_{\mathrm{KL}}\left(q\,\left\|\,p\right.\right) = \mathbb{E}_{\bb{z}\sim q}\left[ \log \frac{q(\bb{Z})}{p(\bb{Z})} \right] $ is the Kullback-Leibler divergence, which can be interpreted as the information gained by using the posterior $q(\bb{Z}|\bb{X})$ instead of the prior distribution $p(\bb{Z})$.

Using the ELBO, the VAE loss becomes, $$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E}_{\bb{x}} \left[ \mathbb{E}_{\bb{z} \sim q_{\bb{\alpha}}}\left[ -\log p_{\bb{\beta}}(\bb{x} | \bb{z}) \right] + \mathcal{D}_{\mathrm{KL}}\left(q_{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z})\right.\right) \right]. $$

By remembering that the likelihood is a Gaussian distribution with a diagonal covariance and by applying the reparametrization trick, we can write the above as

$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E}_{\bb{x}} \left[ \mathbb{E}_{\bb{z} \sim q_{\bb{\alpha}}} \left[ \frac{1}{2\sigma^2}\norm{ \bb{x} - \Psi_{\bb{\beta}}\left( \bb{\mu}_{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}}_{\bb{\alpha}}(\bb{x})\, \bb{u} \right) }_2^2 \right] + \mathcal{D}_{\mathrm{KL}}\left(q_{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z})\right.\right) \right]. $$
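
Since the posterior $q_{\bb{\alpha}}(\bb{Z}|\bb{x})$ is a diagonal Gaussian and the prior is $\mathcal{N}(\bb{0},\bb{I})$, the KL term has a closed form (a standard Gaussian identity, with $d$ the latent-space dimension): $$ \mathcal{D}_{\mathrm{KL}}\left(q_{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z})\right.\right) = \frac{1}{2}\left( \mathrm{tr}\,\bb{\Sigma}_{\bb{\alpha}}(\bb{x}) + \norm{\bb{\mu}_{\bb{\alpha}}(\bb{x})}_2^2 - d - \log\det \bb{\Sigma}_{\bb{\alpha}}(\bb{x}) \right), $$ so it can be computed directly from $\bb{\mu}_{\bb{\alpha}}(\bb{x})$ and $\log\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})$ without sampling.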

Model Implementation

Obviously our model will have two parts, an encoder and a decoder. Since we're working with images, we'll implement both as deep convolutional networks, where the decoder is a "mirror image" of the encoder implemented with adjoint (AKA transposed) convolutions. Between the encoder CNN and the decoder CNN we'll implement the sampling from the parametric posterior approximator $q_{\bb{\alpha}}(\bb{Z}|\bb{x})$ to make it a VAE model and not just a regular autoencoder (of course, this is not yet enough to create a VAE, since we also need a special loss function which we'll get to later).

First, let's implement just the CNN part of the Encoder network (this is not the full $\Phi_{\vec{\alpha}}(\bb{x})$ yet). As usual, it should take an input image and map it to an activation volume of a specified depth. We'll consider this volume as the features we extract from the input image. Later, we'll use these features to create the latent-space representation of the input.

TODO: Implement the EncoderCNN class in the hw3/autoencoder.py module. Implement any CNN architecture you like. If you need "architecture inspiration" you can see e.g. this or this paper.

In [6]:
import hw3.autoencoder as autoencoder

in_channels = 3
out_channels = 1024
encoder_cnn = autoencoder.EncoderCNN(in_channels, out_channels).to(device)
print(encoder_cnn)

h = encoder_cnn(x0)
print(h.shape)

test.assertEqual(h.dim(), 4)
test.assertSequenceEqual(h.shape[0:2], (1, out_channels))
EncoderCNN(
  (cnn): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Conv2d(256, 1024, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (15): ReLU()
  )
)
torch.Size([1, 1024, 4, 4])

Now let's implement the CNN part of the Decoder. Again, this is not yet the full $\Psi _{\bb{\beta}}(\bb{z})$. It should take an activation volume produced by your EncoderCNN and output an image with the same dimensions as the Encoder's input. This should be a CNN which is a "mirror image" of the Encoder: for example, replace convolutions with transposed convolutions, downsampling with upsampling, etc. Consult the documentation of ConvTranspose2d to figure out how to reverse your convolutional layers in terms of input and output dimensions.

TODO: Implement the DecoderCNN class in the hw3/autoencoder.py module.

In [7]:
decoder_cnn = autoencoder.DecoderCNN(in_channels=out_channels, out_channels=in_channels).to(device)
print(decoder_cnn)
x0r = decoder_cnn(h)
print(x0r.shape)

test.assertEqual(x0.shape, x0r.shape)

# Should look like colored noise
T.functional.to_pil_image(x0r[0].cpu().detach())
DecoderCNN(
  (cnn): Sequential(
    (0): ReLU()
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Upsample(scale_factor=2, mode=bilinear)
    (3): Conv2d(1024, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Upsample(scale_factor=2, mode=bilinear)
    (7): Conv2d(256, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (8): ReLU()
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Upsample(scale_factor=2, mode=bilinear)
    (11): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (12): ReLU()
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Upsample(scale_factor=2, mode=bilinear)
    (15): Conv2d(64, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
)
torch.Size([1, 3, 64, 64])
/home/y.chen/miniconda3/envs/cs236605-hw/lib/python3.7/site-packages/torch/nn/modules/upsampling.py:122: UserWarning: nn.Upsampling is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.Upsampling is deprecated. Use nn.functional.interpolate instead.")
/home/y.chen/miniconda3/envs/cs236605-hw/lib/python3.7/site-packages/torch/nn/functional.py:1961: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  "See the documentation of nn.Upsample for details.".format(mode))
Out[7]:

Let's now implement the full VAE Encoder, $\Phi_{\vec{\alpha}}(\vec{x})$. It will work as follows:

  1. Produce a feature vector $\vec{h}$ from the input image $\vec{x}$.
  2. Use two affine transforms to convert the features into the mean and log-variance of the posterior, i.e. $$ \begin{align} \bb{\mu} _{\bb{\alpha}}(\bb{x}) &= \vec{h}\mattr{W}_{\mathrm{h\mu}} + \vec{b}_{\mathrm{h\mu}} \\ \log\left(\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})\right) &= \vec{h}\mattr{W}_{\mathrm{h\sigma^2}} + \vec{b}_{\mathrm{h\sigma^2}} \end{align} $$
  3. Use the reparametrization trick to create the latent representation $\vec{z}$.

Note that we model the log of the variance, not the actual variance. The reason is that the log is easier to optimize, since (a) it doesn't have to be positive, and (b) it has a much larger dynamic range. The above formulation is proposed in appendix C of the VAE paper.

TODO: Implement the encode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__().
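Here is a minimal sketch of these three steps (the layer names are hypothetical, and the real VAE class also holds the decoder):

import torch
import torch.nn as nn

class VAEEncodeSketch(nn.Module):
    def __init__(self, features_encoder, n_features, z_dim):
        super().__init__()
        self.features_encoder = features_encoder
        self.ln_mu = nn.Linear(n_features, z_dim)      # h -> mu
        self.ln_logvar = nn.Linear(n_features, z_dim)  # h -> log(sigma^2)

    def encode(self, x):
        # 1. Feature vector h from the input image.
        h = self.features_encoder(x).flatten(start_dim=1)
        # 2. Affine maps to the posterior's mean and log-variance.
        mu = self.ln_mu(h)
        log_sigma2 = self.ln_logvar(h)
        # 3. Reparametrization trick: z = mu + Sigma^(1/2) u, with u ~ N(0, I).
        u = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * log_sigma2) * u
        return z, mu, log_sigma2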

In [8]:
z_dim = 2
vae = autoencoder.VAE(encoder_cnn, decoder_cnn, x0[0].size(), z_dim).to(device)
print(vae)

z, mu, log_sigma2 = vae.encode(x0)

test.assertSequenceEqual(z.shape, (1, z_dim))
test.assertTrue(z.shape == mu.shape == log_sigma2.shape)

print(f'mu(x0)={list(*mu.detach().cpu().numpy())}, sigma2(x0)={list(*torch.exp(log_sigma2).detach().cpu().numpy())}')

# Sample from q(Z|x)
N = 500
Z = torch.zeros(N, z_dim)
_, ax = plt.subplots()
with torch.no_grad():
    for i in range(500):
        Z[i], _, _ = vae.encode(x0)
        ax.scatter(*Z[i].cpu().numpy())

# Should be close to the above
print('sampled mu', torch.mean(Z, dim=0))
print('sampled sigma2', torch.var(Z, dim=0))
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): ReLU()
      (8): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Conv2d(256, 1024, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (14): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (15): ReLU()
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ReLU()
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Upsample(scale_factor=2, mode=bilinear)
      (3): Conv2d(1024, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU()
      (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Upsample(scale_factor=2, mode=bilinear)
      (7): Conv2d(256, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (8): ReLU()
      (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): Upsample(scale_factor=2, mode=bilinear)
      (11): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (12): ReLU()
      (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): Upsample(scale_factor=2, mode=bilinear)
      (15): Conv2d(64, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (ln_u): Linear(in_features=16384, out_features=2, bias=True)
  (ln_logvar): Linear(in_features=16384, out_features=2, bias=True)
  (ln_rec): Linear(in_features=2, out_features=16384, bias=True)
)
mu(x0)=[-0.13604856, -0.20489696], sigma2(x0)=[1.1163194, 0.92313737]
sampled mu tensor([-0.1523, -0.1955])
sampled sigma2 tensor([1.0880, 0.9444])

Let's now implement the full VAE Decoder, $\Psi _{\bb{\beta}}(\bb{z})$. It will work as follows:

  1. Produce a feature vector $\tilde{\vec{h}}$ from the latent vector $\vec{z}$ using an affine transform.
  2. Reconstruct an image $\tilde{\vec{x}}$ from $\tilde{\vec{h}}$.

TODO: Implement the decode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__(). You may need to also re-run the block above after you implement this.
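A minimal sketch of the decoding path (the tanh squashing to the data's [-1, 1] dynamic range is our assumption, not a requirement):

import torch
import torch.nn as nn

class VAEDecodeSketch(nn.Module):
    def __init__(self, features_decoder, z_dim, features_shape):
        super().__init__()
        self.features_decoder = features_decoder
        self.features_shape = features_shape  # e.g. (1024, 4, 4)
        n_features = int(torch.prod(torch.tensor(features_shape)))
        self.ln_rec = nn.Linear(z_dim, n_features)  # z -> flattened h~

    def decode(self, z):
        # 1. Affine map from latent space back to (flattened) feature space.
        h = self.ln_rec(z)
        # 2. Un-flatten into an activation volume and reconstruct the image.
        h = h.reshape(-1, *self.features_shape)
        return torch.tanh(self.features_decoder(h))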

In [9]:
x0r = vae.decode(z)

test.assertSequenceEqual(x0r.shape, x0.shape)

Our model's forward() function will simply return decode(encode(x)) as well as the calculated mean and log-variance of the posterior.

In [10]:
x0r, mu, log_sigma2 = vae(x0)

test.assertSequenceEqual(x0r.shape, x0.shape)
test.assertSequenceEqual(mu.shape, (1, z_dim))
test.assertSequenceEqual(log_sigma2.shape, (1, z_dim))
T.functional.to_pil_image(x0r[0].detach().cpu())
Out[10]:

Loss Implementation

In practice, since we're using SGD, we'll drop the expectation over $\bb{X}$ and instead sample an instance from the training set and compute a point-wise loss. Similarly, we'll drop the expectation over $\bb{Z}$ by sampling from $q_{\vec{\alpha}}(\bb{Z}|\bb{x})$. Additionally, because the KL divergence is between two Gaussian distributions, there is a closed-form expression for it. These points bring us to the following point-wise loss:

$$ \ell(\vec{\alpha},\vec{\beta};\bb{x}) = \frac{1}{\sigma^2} \left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 + \mathrm{tr}\,\bb{\Sigma} _{\bb{\alpha}}(\bb{x}) + \|\bb{\mu} _{\bb{\alpha}}(\bb{x})\|^2 _2 - d_z - \log\det \bb{\Sigma} _{\bb{\alpha}}(\bb{x}) $$

where $d_z$ is the dimension of the latent space. This pointwise loss is the quantity that we'll compute and minimize with gradient descent.

TODO: Implement the vae_loss() function in the hw3/autoencoder.py module.
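A sketch of the computation, assuming a diagonal $\bb{\Sigma} _{\bb{\alpha}}$ so that $\mathrm{tr}\,\bb{\Sigma} = \sum_j \sigma_j^2$ and $\log\det\bb{\Sigma} = \sum_j \log\sigma_j^2$. The normalization constants (mean over all pixels for the data term, mean over the batch for both terms) are our assumption; the exact value checked by the test below depends on the normalization chosen:

import torch

def vae_loss_sketch(x, xr, z_mu, z_log_sigma2, x_sigma2):
    dz = z_mu.shape[1]

    # Data term: (1/sigma^2) * ||x - xr||_2^2, averaged over all elements.
    data_loss = torch.mean((x - xr) ** 2) / x_sigma2

    # KL term: tr(Sigma) + ||mu||^2 - d_z - log det(Sigma), per sample.
    sigma2 = torch.exp(z_log_sigma2)
    kldiv_loss = torch.mean(
        sigma2.sum(dim=1) + (z_mu ** 2).sum(dim=1) - dz - z_log_sigma2.sum(dim=1)
    )

    return data_loss + kldiv_loss, data_loss, kldiv_loss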

In [11]:
from hw3.autoencoder import vae_loss
torch.manual_seed(42)

def test_vae_loss():
    # Test data
    N, C, H, W = 10, 3, 64, 64 
    z_dim = 32
    x  = torch.randn(N, C, H, W)*2 - 1
    xr = torch.randn(N, C, H, W)*2 - 1
    z_mu = torch.randn(N, z_dim)
    z_log_sigma2 = torch.randn(N, z_dim)
    x_sigma2 = 0.9
    
    loss, _, _ = vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
    
    test.assertAlmostEqual(loss.item(), 10.5053434, delta=1e-5)
    return loss

test_vae_loss()
Out[11]:
tensor(10.5053)

Sampling

The main advantage of a VAE is that it can be used as a generative model by sampling from the latent space, since we optimize for a Normal prior $p(\bb{Z})$ in the loss function. Let's now implement this so that we can visualize how our model is doing as we train.

TODO: Implement the sample() method in the VAE class within the hw3/autoencoder.py module.
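A possible sketch, assuming the VAE stores its latent dimension in a z_dim attribute and that sampling for visualization needs no gradients:

import torch

def vae_sample_sketch(vae, n):
    device = next(vae.parameters()).device
    with torch.no_grad():
        # Draw z from the Normal prior p(Z) and decode it into images.
        z = torch.randn(n, vae.z_dim, device=device)
        samples = vae.decode(z).cpu()
    return [s for s in samples]  # one tensor per sampled image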

In [12]:
samples = vae.sample(5)
_ = plot.tensors_as_images(samples)

Training

Time to train!

TODO:

  1. Implement the VAETrainer class in the hw3/training.py module (a rough sketch of its per-batch logic follows below).
  2. Tweak the hyperparameters in the part2_vae_hyperparams() function within the hw3/answers.py module.
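Here is a rough sketch of the per-batch logic such a trainer might run; the surrounding machinery of epochs, checkpoints and early stopping is omitted, and loss_fn is assumed to return the triple produced by vae_loss:

def vae_train_batch_sketch(model, loss_fn, optimizer, x, device):
    x = x.to(device)
    xr, z_mu, z_log_sigma2 = model(x)  # reconstruction + posterior params
    loss, data_loss, kldiv_loss = loss_fn(x, xr, z_mu, z_log_sigma2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()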
In [13]:
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from hw3.training import VAETrainer
from hw3.answers import part2_vae_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part2_vae_hyperparams()
batch_size = hp['batch_size']
h_dim = hp['h_dim']
z_dim = hp['z_dim']
x_sigma2 = hp['x_sigma2']
learn_rate = hp['learn_rate']
betas = hp['betas']

# Data
split_lengths = [int(len(ds_gwb)*0.9), int(len(ds_gwb)*0.1)]
ds_train, ds_test = random_split(ds_gwb, split_lengths)
dl_train = DataLoader(ds_train, batch_size, shuffle=True)
dl_test  = DataLoader(ds_test,  batch_size, shuffle=True)
im_size = ds_train[0][0].shape

# Model
encoder = autoencoder.EncoderCNN(in_channels=im_size[0], out_channels=h_dim)
decoder = autoencoder.DecoderCNN(in_channels=h_dim, out_channels=im_size[0])
vae = autoencoder.VAE(encoder, decoder, im_size, z_dim)
vae_dp = DataParallel(vae).to(device)

# Optimizer
optimizer = optim.Adam(vae.parameters(), lr=learn_rate, betas=betas)

# Loss
def loss_fn(x, xr, z_mu, z_log_sigma2):
    return autoencoder.vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)

# Trainer
trainer = VAETrainer(vae_dp, loss_fn, optimizer, device)
checkpoint_file = 'checkpoints/vae'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show model and hypers
print(vae)
print(hp)
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): ReLU()
      (8): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Conv2d(256, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (14): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (15): ReLU()
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ReLU()
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Upsample(scale_factor=2, mode=bilinear)
      (3): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU()
      (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Upsample(scale_factor=2, mode=bilinear)
      (7): Conv2d(256, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (8): ReLU()
      (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): Upsample(scale_factor=2, mode=bilinear)
      (11): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (12): ReLU()
      (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): Upsample(scale_factor=2, mode=bilinear)
      (15): Conv2d(64, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
  (ln_u): Linear(in_features=2048, out_features=128, bias=True)
  (ln_logvar): Linear(in_features=2048, out_features=128, bias=True)
  (ln_rec): Linear(in_features=128, out_features=2048, bias=True)
)
{'batch_size': 32, 'h_dim': 128, 'z_dim': 128, 'x_sigma2': 0.9, 'learn_rate': 0.001, 'betas': (0.9, 0.99)}
In [14]:
import IPython.display

def post_epoch_fn(epoch, train_result, test_result, verbose):
    # Plot some samples if this is a verbose epoch
    if verbose:
        samples = vae.sample(n=5)
        fig, _ = plot.tensors_as_images(samples, figsize=(6,2))
        IPython.display.display(fig)
        plt.close(fig)

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    checkpoint_file = checkpoint_file_final
else:
    res = trainer.fit(dl_train, dl_test,
                      num_epochs=200, early_stopping=20, print_every=10,
                      checkpoints=checkpoint_file,
                      post_epoch_fn=post_epoch_fn)
    
# Plot images from best model
saved_state = torch.load(f'{checkpoint_file}.pt', map_location=device)
vae_dp.load_state_dict(saved_state['model_state'])
print('*** Images Generated from best model:')
fig, _ = plot.tensors_as_images(vae_dp.module.sample(n=15), nrows=3, figsize=(6,6))
--- EPOCH 1/200 ---
train_batch (Avg. Loss 0.423, Accuracy 8.5): 100%|██████████| 15/15 [00:04<00:00,  3.60it/s]
test_batch (Avg. Loss 0.419, Accuracy 10.5): 100%|██████████| 2/2 [00:00<00:00,  5.11it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 1
train_batch (Avg. Loss 0.340, Accuracy 10.3): 100%|██████████| 15/15 [00:04<00:00,  3.80it/s]
test_batch (Avg. Loss 0.319, Accuracy 13.1): 100%|██████████| 2/2 [00:00<00:00,  4.88it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 2
train_batch (Avg. Loss 0.303, Accuracy 11.6): 100%|██████████| 15/15 [00:03<00:00,  4.37it/s]
test_batch (Avg. Loss 0.315, Accuracy 13.3): 100%|██████████| 2/2 [00:00<00:00,  6.44it/s]
train_batch (Avg. Loss 0.290, Accuracy 12.1): 100%|██████████| 15/15 [00:03<00:00,  4.35it/s]
test_batch (Avg. Loss 0.288, Accuracy 14.5): 100%|██████████| 2/2 [00:00<00:00,  5.86it/s]
train_batch (Avg. Loss 0.283, Accuracy 12.4): 100%|██████████| 15/15 [00:03<00:00,  4.14it/s]
test_batch (Avg. Loss 0.279, Accuracy 15.0): 100%|██████████| 2/2 [00:00<00:00,  5.80it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 5
train_batch (Avg. Loss 0.272, Accuracy 12.8): 100%|██████████| 15/15 [00:03<00:00,  4.31it/s]
test_batch (Avg. Loss 0.286, Accuracy 14.5): 100%|██████████| 2/2 [00:00<00:00,  6.19it/s]
train_batch (Avg. Loss 0.266, Accuracy 13.2): 100%|██████████| 15/15 [00:03<00:00,  4.36it/s]
test_batch (Avg. Loss 0.268, Accuracy 15.3): 100%|██████████| 2/2 [00:00<00:00,  5.93it/s]
train_batch (Avg. Loss 0.266, Accuracy 13.3): 100%|██████████| 15/15 [00:03<00:00,  4.24it/s]
test_batch (Avg. Loss 0.266, Accuracy 15.2): 100%|██████████| 2/2 [00:00<00:00,  6.16it/s]
train_batch (Avg. Loss 0.265, Accuracy 13.0): 100%|██████████| 15/15 [00:03<00:00,  3.64it/s]
test_batch (Avg. Loss 0.273, Accuracy 14.7): 100%|██████████| 2/2 [00:00<00:00,  4.38it/s]
train_batch (Avg. Loss 0.260, Accuracy 13.5): 100%|██████████| 15/15 [00:03<00:00,  4.41it/s]
test_batch (Avg. Loss 0.271, Accuracy 15.6): 100%|██████████| 2/2 [00:00<00:00,  5.64it/s]
--- EPOCH 11/200 ---
train_batch (Avg. Loss 0.261, Accuracy 13.3): 100%|██████████| 15/15 [00:03<00:00,  4.41it/s]
test_batch (Avg. Loss 0.276, Accuracy 14.8): 100%|██████████| 2/2 [00:00<00:00,  5.84it/s]
train_batch (Avg. Loss 0.263, Accuracy 13.1): 100%|██████████| 15/15 [00:03<00:00,  4.44it/s]
test_batch (Avg. Loss 0.272, Accuracy 15.7): 100%|██████████| 2/2 [00:00<00:00,  6.17it/s]
train_batch (Avg. Loss 0.260, Accuracy 13.4): 100%|██████████| 15/15 [00:03<00:00,  4.66it/s]
test_batch (Avg. Loss 0.272, Accuracy 15.2): 100%|██████████| 2/2 [00:00<00:00,  6.38it/s]
train_batch (Avg. Loss 0.260, Accuracy 13.4): 100%|██████████| 15/15 [00:03<00:00,  4.66it/s]
test_batch (Avg. Loss 0.271, Accuracy 16.1): 100%|██████████| 2/2 [00:00<00:00,  6.81it/s]
train_batch (Avg. Loss 0.257, Accuracy 13.7): 100%|██████████| 15/15 [00:03<00:00,  4.63it/s]
test_batch (Avg. Loss 0.266, Accuracy 15.4): 100%|██████████| 2/2 [00:00<00:00,  6.30it/s]
train_batch (Avg. Loss 0.254, Accuracy 13.9): 100%|██████████| 15/15 [00:03<00:00,  4.62it/s]
test_batch (Avg. Loss 0.270, Accuracy 15.4): 100%|██████████| 2/2 [00:00<00:00,  6.50it/s]
train_batch (Avg. Loss 0.255, Accuracy 13.7): 100%|██████████| 15/15 [00:03<00:00,  4.60it/s]
test_batch (Avg. Loss 0.268, Accuracy 15.2): 100%|██████████| 2/2 [00:00<00:00,  6.63it/s]
train_batch (Avg. Loss 0.251, Accuracy 14.0): 100%|██████████| 15/15 [00:03<00:00,  4.64it/s]
test_batch (Avg. Loss 0.268, Accuracy 15.6): 100%|██████████| 2/2 [00:00<00:00,  6.48it/s]
train_batch (Avg. Loss 0.252, Accuracy 14.1): 100%|██████████| 15/15 [00:03<00:00,  4.62it/s]
test_batch (Avg. Loss 0.260, Accuracy 16.4): 100%|██████████| 2/2 [00:00<00:00,  6.15it/s]
train_batch (Avg. Loss 0.255, Accuracy 13.8): 100%|██████████| 15/15 [00:03<00:00,  4.59it/s]
test_batch (Avg. Loss 0.264, Accuracy 15.6): 100%|██████████| 2/2 [00:00<00:00,  6.54it/s]
--- EPOCH 21/200 ---
train_batch (Avg. Loss 0.250, Accuracy 14.0): 100%|██████████| 15/15 [00:03<00:00,  4.43it/s]
test_batch (Avg. Loss 0.270, Accuracy 15.8): 100%|██████████| 2/2 [00:00<00:00,  6.61it/s]
train_batch (Avg. Loss 0.251, Accuracy 14.0): 100%|██████████| 15/15 [00:03<00:00,  4.65it/s]
test_batch (Avg. Loss 0.258, Accuracy 16.5): 100%|██████████| 2/2 [00:00<00:00,  6.37it/s]
train_batch (Avg. Loss 0.250, Accuracy 14.0): 100%|██████████| 15/15 [00:03<00:00,  4.53it/s]
test_batch (Avg. Loss 0.265, Accuracy 15.5): 100%|██████████| 2/2 [00:00<00:00,  6.27it/s]
train_batch (Avg. Loss 0.249, Accuracy 14.1): 100%|██████████| 15/15 [00:03<00:00,  4.56it/s]
test_batch (Avg. Loss 0.264, Accuracy 15.5): 100%|██████████| 2/2 [00:00<00:00,  6.54it/s]
train_batch (Avg. Loss 0.252, Accuracy 13.9): 100%|██████████| 15/15 [00:03<00:00,  4.59it/s]
test_batch (Avg. Loss 0.282, Accuracy 15.0): 100%|██████████| 2/2 [00:00<00:00,  6.71it/s]
train_batch (Avg. Loss 0.252, Accuracy 13.8): 100%|██████████| 15/15 [00:03<00:00,  4.67it/s]
test_batch (Avg. Loss 0.263, Accuracy 15.4): 100%|██████████| 2/2 [00:00<00:00,  6.53it/s]
*** Images Generated from best model:

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [15]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

What does the $\sigma^2$ hyperparameter (x_sigma2 in the code) do? Explain the effect of low and high values.

In [16]:
display_answer(hw3.answers.part2_q1)

Your answer: $\sigma^2$ is a hyper-parameter which controls the regularization strength, i.e. the trade-off between the two loss terms.

The data loss is the difference between the reconstructed picture and the original picture; it is the data-fitting term. The smaller $\sigma^2$ is, the more the data loss dominates the final loss. As a result, the generated pictures become more similar to the original pictures, while the generalization ability is compromised.

The KL divergence can be interpreted as the information gained by using the posterior instead of the prior distribution; it is the regularization term. The larger $\sigma^2$ is, the more the KL term dominates, and the less likely the training process is to over-fit.

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 3: Generative Adversarial Networks

In this part we will implement and train a generative adversarial network and apply it to the task of image generation.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Obtaining the dataset

We'll use the same data as in Part 2.

As before, to use a custom dataset, edit the PART3_CUSTOM_DATA_URL variable in hw3/answers.py.

In [2]:
import cs236605.plot as plot
import cs236605.download
from hw3.answers import PART3_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236605.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/y.chen/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/y.chen/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/y.chen/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extracted images:

In [3]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [4]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(10,5), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [5]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

Generative Adversarial Nets (GANs)

GANs, first proposed in a paper by Ian Goodfellow in 2014, are today arguably the most popular type of generative model, currently producing state-of-the-art results in generative tasks across many different domains.

In a GAN model, two different neural networks compete against each other: A generator and a discriminator.

  • The Generator, which we'll denote as $\Psi _{\bb{\gamma}} : \mathcal{U} \rightarrow \mathcal{X}$, maps a latent-space variable $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ to an instance-space variable $\bb{x}$ (e.g. an image). Thus a parametric evidence distribution $p_{\bb{\gamma}}(\bb{X})$ is generated, which we typically would like to be as close as possible to the real evidence distribution, $p(\bb{X})$.

  • The Discriminator, $\Delta _{\bb{\delta}} : \mathcal{X} \rightarrow [0,1]$, is a network which, given an instance-space variable $\bb{x}$, returns the probability that $\bb{x}$ is real, i.e. that $\bb{x}$ was sampled from $p(\bb{X})$ and not $p_{\bb{\gamma}}(\bb{X})$.

Training GANs

The generator is trained to generate "fake" instances which maximally fool the discriminator into classifying them as real. Mathematically, the generator's parameters $\bb{\gamma}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

The discriminator is trained to classify between real images, coming from the training set, and fake images generated by the generator. Mathematically, the discriminator's parameters $\bb{\delta}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

These two competing objectives can thus be expressed as the following min-max optimization: $$ \min _{\bb{\gamma}} \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

A key insight into GANs is that we can interpret the above maximum as the loss with respect to $\bb{\gamma}$:

$$ L({\bb{\gamma}}) = \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

This means that the generator's loss function trains together with the generator itself in an adversarial manner. In contrast, when training our VAE we used a fixed L2 norm as a data loss term.

Model Implementation

We'll now implement a Deep Convolutional GAN (DCGAN) model. See the DCGAN paper for architecture ideas and tips for training.

TODO: Implement the Discriminator class in the hw3/gan.py module. If you wish you can reuse the EncoderCNN class from the VAE model as the first part of the Discriminator.

In [6]:
import hw3.gan as gan

dsc = gan.Discriminator(in_size=x0[0].shape).to(device)
print(dsc)

d0 = dsc(x0)
print(d0.shape)

test.assertSequenceEqual(d0.shape, (1,1))
Discriminator(
  (feature_encode): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): ReLU()
      (8): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Conv2d(256, 1024, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (14): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (15): ReLU()
    )
  )
  (linear): Linear(in_features=16384, out_features=1, bias=True)
)
torch.Size([1, 1])

TODO: Implement the Generator class in the hw3/gan.py module. If you wish you can reuse the DecoderCNN class from the VAE model as the last part of the Generator.

In [7]:
z_dim = 128
gen = gan.Generator(z_dim, 4).to(device)
print(gen)

z = torch.randn(1, z_dim).to(device)
xr = gen(z)
print(xr.shape)

test.assertSequenceEqual(x0.shape, xr.shape)
Generator(
  (ln_rec): Linear(in_features=128, out_features=16384, bias=True)
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ReLU()
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Upsample(scale_factor=2, mode=bilinear)
      (3): Conv2d(1024, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU()
      (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Upsample(scale_factor=2, mode=bilinear)
      (7): Conv2d(256, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (8): ReLU()
      (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): Upsample(scale_factor=2, mode=bilinear)
      (11): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (12): ReLU()
      (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (14): Upsample(scale_factor=2, mode=bilinear)
      (15): Conv2d(64, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
  )
)
torch.Size([1, 3, 64, 64])
/home/y.chen/miniconda3/envs/cs236605-hw/lib/python3.7/site-packages/torch/nn/modules/upsampling.py:122: UserWarning: nn.Upsampling is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.Upsampling is deprecated. Use nn.functional.interpolate instead.")
/home/y.chen/miniconda3/envs/cs236605-hw/lib/python3.7/site-packages/torch/nn/functional.py:1961: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  "See the documentation of nn.Upsample for details.".format(mode))

Loss Implementation

Let's begin with the discriminator's loss function. Based on the above we can flip the sign and say we want to update the Discriminator's parameters $\bb{\delta}$ so that they minimize the expression $$ -\mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, - \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

We're using the Discriminator twice in this expression; once to classify data from the real data distribution and once again to classify generated data. Therefore our loss should be computed based on these two terms. Notice that since the discriminator returns a probability, we can formulate the above as two cross-entropy losses.

GANs are notoriously difficult to train. One common trick for improving GAN stability during training is to make the classification labels noisy for the discriminator. This can be seen as a form of regularization, which helps prevent the discriminator from overfitting.

We'll incorporate this idea into our loss function. Instead of labels being equal to 0 or 1, we'll make them "fuzzy", i.e. random numbers in the ranges $[0\pm\epsilon]$ and $[1\pm\epsilon]$.

TODO: Implement the discriminator_loss_fn() function in the hw3/gan.py module.
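A sketch of one way to implement this, assuming the discriminator outputs raw scores (logits) and that label_noise is the total width of the fuzzy-label interval:

import torch
import torch.nn.functional as F

def discriminator_loss_sketch(y_data, y_generated, data_label=1, label_noise=0.3):
    def fuzzy(label_center, like):
        # Uniform noise in label_center +- label_noise/2.
        return label_center + torch.rand_like(like) * label_noise - label_noise / 2

    # Real samples should be classified as data_label...
    loss_data = F.binary_cross_entropy_with_logits(y_data, fuzzy(data_label, y_data))
    # ...and generated samples as the opposite label.
    loss_generated = F.binary_cross_entropy_with_logits(
        y_generated, fuzzy(1 - data_label, y_generated))
    return loss_data + loss_generated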

In [8]:
from hw3.gan import discriminator_loss_fn
torch.manual_seed(42)

y_data = torch.rand(10) * 10
y_generated = torch.rand(10) * 10

loss = discriminator_loss_fn(y_data, y_generated, data_label=1, label_noise=0.3)
print(loss)

test.assertAlmostEqual(loss.item(), 6.4808731, delta=1e-5)
tensor(6.4809)
/home/y.chen/miniconda3/envs/cs236605-hw/lib/python3.7/site-packages/torch/nn/functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
  warnings.warn(warning.format(ret))

Similarly, the generator's parameters $\bb{\gamma}$ should minimize the expression $$ -\mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )) $$

which can also be seen as a cross-entropy term.

TODO: Implement the generator_loss_fn() function in the hw3/gan.py module.
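A sketch under the same logits assumption as above; here the targets are the (non-noisy) "real" label, since the generator wants its samples classified as real:

import torch
import torch.nn.functional as F

def generator_loss_sketch(y_generated, data_label=1):
    target = torch.full_like(y_generated, float(data_label))
    return F.binary_cross_entropy_with_logits(y_generated, target)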

In [9]:
from hw3.gan import generator_loss_fn
torch.manual_seed(42)

y_generated = torch.rand(20) * 10

loss = generator_loss_fn(y_generated, data_label=1)
print(loss)

test.assertAlmostEqual(loss.item(), 0.0222969, delta=1e-5)
tensor(0.0223)

Sampling

Sampling from a GAN is straightforward, since it learns to generate data from an isotropic Gaussian latent space distribution.

There is an important nuance, however. Sampling is required during the process of training the GAN, since we generate fake images to show the discriminator. As you'll see in the next section, in some cases we'll need our samples to have gradients.

TODO: Implement the sample() method in the Generator class within the hw3/gan.py module.
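A possible sketch, assuming the Generator stores its latent dimension in a z_dim attribute:

import torch

def gan_sample_sketch(gen, n, with_grad=False):
    # Draw latent vectors and map them through the generator. Gradients are
    # tracked only when requested, e.g. when the samples will be used to
    # train the generator itself.
    device = next(gen.parameters()).device
    z = torch.randn(n, gen.z_dim, device=device)
    if with_grad:
        return gen(z)
    with torch.no_grad():
        return gen(z)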

In [10]:
samples = gen.sample(5, with_grad=False)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNone(samples.grad_fn)
_ = plot.tensors_as_images(samples.cpu())

samples = gen.sample(5, with_grad=True)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNotNone(samples.grad_fn)

Training

Training GANs is a bit different, since we need to train two models simultaneously, each with its own separate loss function and optimizer. We'll implement the training logic as a function that handles one batch of data and updates both the discriminator and the generator based on it.

As mentioned above, GANs are considered hard to train. To get some ideas and tips you can see this paper, this list of "GAN hacks" or just do it the hard way :)

TODO:

  1. Implement the train_batch function in the hw3/gan.py module (see the sketch below).
  2. Tweak the hyperparameters in the part3_gan_hyperparams() function within the hw3/answers.py module.
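A rough sketch of one combined GAN step, under the interfaces defined above:

def gan_train_batch_sketch(dsc_model, gen_model, dsc_loss_fn, gen_loss_fn,
                           dsc_optimizer, gen_optimizer, x_data):
    N = x_data.shape[0]

    # Discriminator step: real batch vs. a *detached* fake batch, so only
    # the discriminator's parameters receive gradients.
    dsc_optimizer.zero_grad()
    x_fake = gen_model.sample(N, with_grad=False)
    dsc_loss = dsc_loss_fn(dsc_model(x_data), dsc_model(x_fake))
    dsc_loss.backward()
    dsc_optimizer.step()

    # Generator step: gradients must flow through the discriminator back
    # into the generator, so we sample with with_grad=True.
    gen_optimizer.zero_grad()
    x_fake = gen_model.sample(N, with_grad=True)
    gen_loss = gen_loss_fn(dsc_model(x_fake))
    gen_loss.backward()
    gen_optimizer.step()

    return dsc_loss.item(), gen_loss.item()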
In [11]:
import torch.optim as optim
from torch.utils.data import DataLoader
from hw3.answers import part3_gan_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part3_gan_hyperparams()
batch_size = hp['batch_size']
z_dim = hp['z_dim']

# Data
dl_train = DataLoader(ds_gwb, batch_size, shuffle=True)
im_size = ds_gwb[0][0].shape

# Model
dsc = gan.Discriminator(im_size).to(device)
gen = gan.Generator(z_dim, featuremap_size=4).to(device)

# Optimizer
def create_optimizer(model_params, opt_params):
    opt_params = opt_params.copy()
    optimizer_type = opt_params['type']
    opt_params.pop('type')
    return optim.__dict__[optimizer_type](model_params, **opt_params)
dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
gen_optimizer = create_optimizer(gen.parameters(), hp['generator_optimizer'])

# Loss
def dsc_loss_fn(y_data, y_generated):
    return gan.discriminator_loss_fn(y_data, y_generated, hp['data_label'], hp['label_noise'])

def gen_loss_fn(y_generated):
    return gan.generator_loss_fn(y_generated, hp['data_label'])

# Training
checkpoint_file = 'checkpoints/gan'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show hypers
print(hp)
{'batch_size': 32, 'z_dim': 32, 'data_label': 1, 'label_noise': 0.5, 'discriminator_optimizer': {'type': 'Adam', 'lr': 0.0001, 'weight_decay': 0.001}, 'generator_optimizer': {'type': 'Adam', 'lr': 0.0001, 'weight_decay': 0.001}}
In [12]:
import IPython.display
import tqdm
from hw3.gan import train_batch

num_epochs = 100

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    num_epochs = 0
    gen = torch.load(f'{checkpoint_file_final}.pt', map_location=device)
    checkpoint_file = checkpoint_file_final

for epoch_idx in range(num_epochs):
    # We'll accumulate batch losses and show an average once per epoch.
    dsc_losses = []
    gen_losses = []
    print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')
    
    with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
        for batch_idx, (x_data, _) in enumerate(dl_train):
            x_data = x_data.to(device)
            dsc_loss, gen_loss = train_batch(
                dsc, gen,
                dsc_loss_fn, gen_loss_fn,
                dsc_optimizer, gen_optimizer,
                x_data)
            dsc_losses.append(dsc_loss)
            gen_losses.append(gen_loss)
            pbar.update()

    dsc_avg_loss, gen_avg_loss = np.mean(dsc_losses), np.mean(gen_losses)
    print(f'Discriminator loss: {dsc_avg_loss}')
    print(f'Generator loss:     {gen_avg_loss}')
        
    samples = gen.sample(5, with_grad=False)
    fig, _ = plot.tensors_as_images(samples.cpu(), figsize=(6,2))
    IPython.display.display(fig)
    plt.close(fig)
--- EPOCH 1/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.18it/s]
Discriminator loss: 0.665591208373799
Generator loss:     2.3431650400161743
--- EPOCH 2/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.11it/s]
Discriminator loss: 0.2863674935172586
Generator loss:     4.083750247955322
--- EPOCH 3/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.08it/s]
Discriminator loss: 0.1929442304022172
Generator loss:     4.673793526256786
--- EPOCH 4/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.99it/s]
Discriminator loss: -0.053004746489665085
Generator loss:     5.5955470029045555
--- EPOCH 5/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.49it/s]
Discriminator loss: -0.015169467996148503
Generator loss:     6.062840798321893
--- EPOCH 6/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.40it/s]
Discriminator loss: 0.019182535655358258
Generator loss:     7.249737234676585
--- EPOCH 7/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.56it/s]
Discriminator loss: -0.11012538380044348
Generator loss:     8.21740088743322
--- EPOCH 8/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.52it/s]
Discriminator loss: 0.035664471195024604
Generator loss:     8.10509303036858
--- EPOCH 9/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.62it/s]
Discriminator loss: -0.018163188853684592
Generator loss:     8.468778722426471
--- EPOCH 10/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.61it/s]
Discriminator loss: -0.024105964776347664
Generator loss:     8.386213386760039
--- EPOCH 11/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.56it/s]
Discriminator loss: 0.05831228722544277
Generator loss:     7.85639269211713
--- EPOCH 12/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.50it/s]
Discriminator loss: 0.11592234835466918
Generator loss:     8.504676145665785
--- EPOCH 13/100 ---
100%|██████████| 17/17 [00:05<00:00,  2.98it/s]
Discriminator loss: -0.04945676160209319
Generator loss:     8.970325554118437
--- EPOCH 14/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.02it/s]
Discriminator loss: -0.10393608021823798
Generator loss:     8.230159871718463
--- EPOCH 15/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.01it/s]
Discriminator loss: 0.20741683945936315
Generator loss:     7.4191092883839325
--- EPOCH 16/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.00it/s]
Discriminator loss: -0.02840136123054168
Generator loss:     9.303709114299101
--- EPOCH 17/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.06it/s]
Discriminator loss: 0.039668788147323275
Generator loss:     8.299091703751508
--- EPOCH 18/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.05it/s]
Discriminator loss: 0.10123807783512508
Generator loss:     6.9044049767886895
--- EPOCH 19/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.52it/s]
Discriminator loss: 0.07416567622738726
Generator loss:     8.073924232931699
--- EPOCH 20/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.54it/s]
Discriminator loss: 0.09951188108500313
Generator loss:     8.827511338626637
--- EPOCH 21/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.46it/s]
Discriminator loss: -0.032672519412110836
Generator loss:     6.8183479870066925
--- EPOCH 22/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.29it/s]
Discriminator loss: -0.026338994174319154
Generator loss:     8.687284553752226
--- EPOCH 23/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.36it/s]
Discriminator loss: -0.005469306865159203
Generator loss:     9.03015251720653
--- EPOCH 24/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.49it/s]
Discriminator loss: 0.17202730433029287
Generator loss:     9.537726766922894
--- EPOCH 25/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.50it/s]
Discriminator loss: 0.15245305440005133
Generator loss:     8.151815077837776
--- EPOCH 26/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.11it/s]
Discriminator loss: 0.0770076150622438
Generator loss:     7.722934386309455
--- EPOCH 27/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.02it/s]
Discriminator loss: 0.11119510233402252
Generator loss:     6.263515219968908
--- EPOCH 28/100 ---
100%|██████████| 17/17 [00:06<00:00,  2.83it/s]
Discriminator loss: -0.017324350007316646
Generator loss:     6.13984624077292
--- EPOCH 29/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.05it/s]
Discriminator loss: 0.0380115223062389
Generator loss:     7.720717009376077
--- EPOCH 30/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.06it/s]
Discriminator loss: 0.15291381189051798
Generator loss:     8.588446294560152
--- EPOCH 31/100 ---
100%|██████████| 17/17 [00:06<00:00,  3.31it/s]
Discriminator loss: 0.24672520314069354
Generator loss:     7.837055444717407
--- EPOCH 32/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.59it/s]
Discriminator loss: 0.0960334331235465
Generator loss:     6.942128405851476
--- EPOCH 33/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.52it/s]
Discriminator loss: 0.08444554621682447
Generator loss:     5.85062697354485
--- EPOCH 34/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.58it/s]
Discriminator loss: 0.005215063691139221
Generator loss:     9.078670698053697
--- EPOCH 35/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.55it/s]
Discriminator loss: -0.010717516874565798
Generator loss:     7.793759963091682
--- EPOCH 36/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.61it/s]
Discriminator loss: 0.15688405974822886
Generator loss:     8.344012344584746
--- EPOCH 37/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.44it/s]
Discriminator loss: 0.046054218621814955
Generator loss:     6.3486772986019355
--- EPOCH 38/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.60it/s]
Discriminator loss: 0.10395300234941875
Generator loss:     6.421184034908519
--- EPOCH 39/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.46it/s]
Discriminator loss: 0.14793037469772732
Generator loss:     6.004163223154404
--- EPOCH 40/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.42it/s]
Discriminator loss: 0.21047393221627264
Generator loss:     6.4160480779760025
--- EPOCH 41/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.59it/s]
Discriminator loss: 0.12729758352917783
Generator loss:     8.197409643846399
--- EPOCH 42/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.54it/s]
Discriminator loss: 0.0952007376095828
Generator loss:     6.107181675293866
--- EPOCH 43/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.52it/s]
Discriminator loss: 0.147759259404505
Generator loss:     6.783079820520737
--- EPOCH 44/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.47it/s]
Discriminator loss: 0.06832839614328216
Generator loss:     7.6897071389591
--- EPOCH 45/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.55it/s]
Discriminator loss: -0.08225327686351888
Generator loss:     8.949194431304932
--- EPOCH 46/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.43it/s]
Discriminator loss: 0.23957223346566453
Generator loss:     9.321126068339629
--- EPOCH 47/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.58it/s]
Discriminator loss: -0.00029547161915723016
Generator loss:     6.626183734220617
--- EPOCH 48/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.60it/s]
Discriminator loss: -0.014370116898242165
Generator loss:     6.057347297668457
--- EPOCH 49/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.56it/s]
Discriminator loss: -0.005339245147564832
Generator loss:     9.92268565121819
--- EPOCH 50/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.49it/s]
Discriminator loss: 0.03893971969099606
Generator loss:     8.609769119935876
--- EPOCH 51/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.44it/s]
Discriminator loss: 0.096056257319801
Generator loss:     6.730131485882928
--- EPOCH 52/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.58it/s]
Discriminator loss: 0.08519158950623344
Generator loss:     5.740005184622372
--- EPOCH 53/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.50it/s]
Discriminator loss: 0.03644390404224396
Generator loss:     6.692349966834573
--- EPOCH 54/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.54it/s]
Discriminator loss: 0.0598052207599668
Generator loss:     7.62394764844109
--- EPOCH 55/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.58it/s]
Discriminator loss: 0.03218997751965242
Generator loss:     8.718012683531818
--- EPOCH 56/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.59it/s]
Discriminator loss: 0.11019610219142016
Generator loss:     9.26663485695334
--- EPOCH 57/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.62it/s]
Discriminator loss: 0.03265893092269406
Generator loss:     8.66035673197578
--- EPOCH 58/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.60it/s]
Discriminator loss: 0.0040100063471233145
Generator loss:     8.009502663331872
--- EPOCH 59/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.52it/s]
Discriminator loss: -0.13879421397167094
Generator loss:     9.3203044498668
--- EPOCH 60/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.58it/s]
Discriminator loss: -0.03605940444942783
Generator loss:     8.582344952751608
--- EPOCH 61/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.63it/s]
Discriminator loss: 0.08603168585721184
Generator loss:     8.202338302836699
--- EPOCH 62/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.60it/s]
Discriminator loss: 0.1287456227137762
Generator loss:     6.461713594548843
--- EPOCH 63/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.53it/s]
Discriminator loss: -0.1399045805720722
Generator loss:     9.011455031002269
--- EPOCH 64/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.44it/s]
Discriminator loss: -0.10060867973986794
Generator loss:     15.94050816928639
--- EPOCH 65/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.50it/s]
Discriminator loss: -0.0044250242850359745
Generator loss:     7.120770510505228
--- EPOCH 66/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.10it/s]
Discriminator loss: 0.18552147465593674
Generator loss:     8.751095126656924
--- EPOCH 67/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.38it/s]
Discriminator loss: 0.10677134464768802
Generator loss:     7.543191236608169
--- EPOCH 68/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.08it/s]
Discriminator loss: 0.347675128675559
Generator loss:     7.601417737848618
--- EPOCH 69/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.50it/s]
Discriminator loss: -0.026023197042591432
Generator loss:     6.73066837647382
--- EPOCH 70/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.55it/s]
Discriminator loss: 0.10128249732010505
Generator loss:     7.389584737665513
--- EPOCH 71/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.50it/s]
Discriminator loss: -0.030069149154074052
Generator loss:     8.218639738419476
--- EPOCH 72/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.53it/s]
Discriminator loss: 0.05669032475527595
Generator loss:     7.831343370325425
--- EPOCH 73/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.49it/s]
Discriminator loss: 0.10164742566206876
Generator loss:     8.54612263511209
--- EPOCH 74/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.53it/s]
Discriminator loss: 0.027234021574258804
Generator loss:     10.128244119531969
--- EPOCH 75/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.31it/s]
Discriminator loss: 0.2724829386262333
Generator loss:     12.497187684564029
--- EPOCH 76/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.11it/s]
Discriminator loss: -0.020117213182589588
Generator loss:     13.319724812227136
--- EPOCH 77/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.47it/s]
Discriminator loss: 0.05115117907852811
Generator loss:     8.474001547869515
--- EPOCH 78/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.54it/s]
Discriminator loss: 0.033022546592880696
Generator loss:     7.444225423476276
--- EPOCH 79/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.60it/s]
Discriminator loss: -0.05359134801170405
Generator loss:     8.438564384684843
--- EPOCH 80/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.53it/s]
Discriminator loss: 0.08757808519636884
Generator loss:     7.029635611702414
--- EPOCH 81/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.62it/s]
Discriminator loss: -0.0722563617369708
Generator loss:     8.983774381525377
--- EPOCH 82/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.56it/s]
Discriminator loss: -0.03539140506044907
Generator loss:     8.86016009835636
--- EPOCH 83/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.62it/s]
Discriminator loss: 0.1279124429120737
Generator loss:     11.456101165098303
--- EPOCH 84/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.62it/s]
Discriminator loss: 0.1243582137805574
Generator loss:     6.261760010438807
--- EPOCH 85/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.69it/s]
Discriminator loss: 0.039988479193519146
Generator loss:     9.145460212931914
--- EPOCH 86/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.60it/s]
Discriminator loss: 0.009376979268649045
Generator loss:     7.371882943546071
--- EPOCH 87/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.53it/s]
Discriminator loss: -0.13725585578119054
Generator loss:     9.266328755547018
--- EPOCH 88/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.53it/s]
Discriminator loss: 0.021769460290670395
Generator loss:     9.903851705438951
--- EPOCH 89/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.53it/s]
Discriminator loss: 0.10036879339638878
Generator loss:     11.025622087366441
--- EPOCH 90/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.47it/s]
Discriminator loss: 0.13979519268169122
Generator loss:     8.969646594103645
--- EPOCH 91/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.38it/s]
Discriminator loss: 0.022841988240971285
Generator loss:     10.262443921145271
--- EPOCH 92/100 ---
100%|██████████| 17/17 [00:05<00:00,  2.81it/s]
Discriminator loss: 0.05740730157669853
Generator loss:     12.123285742367015
--- EPOCH 93/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.52it/s]
Discriminator loss: -0.17896830202902064
Generator loss:     13.873863472658044
--- EPOCH 94/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.47it/s]
Discriminator loss: 0.10434690424624611
Generator loss:     11.158083747414981
--- EPOCH 95/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.44it/s]
Discriminator loss: 0.011593268855529673
Generator loss:     12.158823013305664
--- EPOCH 96/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.47it/s]
Discriminator loss: 0.014874162481111638
Generator loss:     11.237943761488971
--- EPOCH 97/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.56it/s]
Discriminator loss: 0.08443098252310473
Generator loss:     14.037939464344698
--- EPOCH 98/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.45it/s]
Discriminator loss: 0.040293567978283935
Generator loss:     8.355806238511029
--- EPOCH 99/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.50it/s]
Discriminator loss: 0.06208418309688568
Generator loss:     8.385891858269186
--- EPOCH 100/100 ---
100%|██████████| 17/17 [00:05<00:00,  3.43it/s]
Discriminator loss: 0.15436158055330024
Generator loss:     6.914724924985101
In [13]:
# Plot images from best or last model
if os.path.isfile(f'{checkpoint_file}.pt'):
    gen = torch.load(f'{checkpoint_file}.pt', map_location=device)
print('*** Images Generated from best model:')
samples = gen.sample(n=15, with_grad=False).cpu()
fig, _ = plot.tensors_as_images(samples, nrows=3, figsize=(6,6))
*** Images Generated from best model:

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [14]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

Explain in detail why during training we sometimes need to maintain gradients when sampling from the GAN, and other times we don't. When are they maintained and why? When are they discarded and why?

In [15]:
display_answer(hw3.answers.part3_q1)

Your answer: When training the discriminator, its objective is to learn from the supplied dataset how to distinguish real from fake samples; during this part of GAN training only the discriminator's parameters are updated, so the fake samples are generated without maintaining gradients. There is no need to back-propagate into the generator, and discarding the gradients saves memory and computation.

When training the generator, the fake samples are presented to the discriminator together with the "real" label (that is, 1.0), and the optimizer computes generator parameter updates based on that label and the discriminator's own prediction. In other words, the discriminator has some level of doubt about its prediction, and the GAN takes that into consideration: gradients are maintained so that they can back-propagate from the last layer of the discriminator down to the first layer of the generator. In most practices the discriminator's parameters are temporarily frozen during this phase; the generator uses the gradients to update its parameters and improve its ability to synthesize fake samples.

Question 2

  1. When training a GAN to generate images, should we decide to stop training solely based on the fact that the Generator loss is below some threshold? Why or why not?

  2. What does it mean if the discriminator loss remains at a constant value while the generator loss decreases?

In [16]:
display_answer(hw3.answers.part3_q2)

Your answer:

1. No, we shouldn't stop training solely because the Generator loss is below some threshold. When the generator loss is small, the generated fake pictures are more similar to the real pictures and thus more likely to confuse the discriminator. However, the ultimate goal of GAN training is to reduce the discriminator loss while decreasing the generator loss at the same time, so as to obtain a discriminator that can distinguish even fake pictures which are very similar to the real ones.

2. The discriminator has become too strong relative to the generator. Beyond this point, the generator finds it almost impossible to fool the discriminator, hence the increase in its loss.

Question 3

Compare the results you got when generating images with the VAE to the GAN results. What's the main difference and what's causing it?

In [17]:
display_answer(hw3.answers.part3_q3)

Your answer:

The GAN generates better pictures than the VAE's blurry output. However, the VAE learns a hidden representation of the data better.

Autoencoders learn a given distribution by comparing their input to their output. This is good for learning hidden representations of the data, but pretty bad for generating new data: we essentially learn an averaged representation of the data, so the output becomes quite blurry.

Generative Adversarial Networks take an entirely different approach. They use another network (the so-called Discriminator) to measure the distance between the generated and the real data. Essentially, the Discriminator distinguishes real data from generated data: it receives some data as input and returns a number between 0 and 1, where 0 means the data is fake and 1 means it is real. The generator's goal is then to learn to convince the Discriminator that it is generating real data.